Skip to content

feat: SimulatorInterferometer(use_jax=True) + xp-aware preprocess Gaussian noise#336

Merged
Jammy2211 merged 1 commit into
mainfrom
feature/simulator-interferometer-use-jax
May 24, 2026
Merged

feat: SimulatorInterferometer(use_jax=True) + xp-aware preprocess Gaussian noise#336
Jammy2211 merged 1 commit into
mainfrom
feature/simulator-interferometer-use-jax

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

PR 3 of Phase 2 (z_features/jax_user_intro.md). Adds use_jax=True to aa.SimulatorInterferometer and routes the complex-Gaussian visibility noise pipeline through jax.random on the JAX path. Also fixes the pre-existing signature asymmetry — via_image_from now accepts xp=None matching aa.SimulatorImaging.

Eager JAX path works: simulator.via_image_from(image) with use_jax=True returns an Interferometer with jax.Array visibilities. @jax.jit wrap is currently blocked by the same Array2D.native jit-incompatibility flagged in the SimulatorImaging PR — affects all simulators until a separate slim/native reshape refactor lands.

Companion changes in PyAutoLens (autolens/interferometer/simulator.py) and PyAutoGalaxy (autogalaxy/interferometer/simulator.py) forward xp from the new _xp property on the parent.

API Changes

  • Added: aa.SimulatorInterferometer(..., use_jax=False) constructor flag.
  • Added: aa.SimulatorInterferometer._xp property.
  • Changed signature: aa.SimulatorInterferometer.via_image_from(image, xp=None) — fixed asymmetry with SimulatorImaging by adding the xp parameter. xp=None defaults to self._xp.
  • Changed signature: preprocess.gaussian_noise_via_shape_and_sigma_from(..., xp=np), data_with_gaussian_noise_added(..., xp=np), data_with_complex_gaussian_noise_added(..., xp=np) all gain xp parameter.
  • Changed behaviour: NaN-noise-map runtime guard now NumPy-side only.

See full details below.

Test Plan

  • 3 new unit tests in test_autoarray/dataset/interferometer/test_simulator_use_jax.py — constructor wiring, signature symmetry.
  • 834 PyAutoArray tests pass (no regressions).
  • Cross-xp parity script at autolens_workspace_test/scripts/interferometer/simulator_use_jax_parity.py confirms eager NumPy and JAX paths produce identical visibilities to atol=1e-8 (200-visibility noise-free Sersic + Isothermal lens).
Full API Changes

Added

  • aa.SimulatorInterferometer(use_jax: bool = False) constructor parameter.
  • aa.SimulatorInterferometer._xp property — returns jax.numpy if self.use_jax, else numpy.

Changed signature

  • aa.SimulatorInterferometer.via_image_from(image, xp=None) — fixed signature asymmetry vs aa.SimulatorImaging.via_image_from (which already had xp=). xp=None defaults to self._xp.
  • preprocess.gaussian_noise_via_shape_and_sigma_from(shape, sigma, seed=-1, xp=np).
  • preprocess.data_with_gaussian_noise_added(data, sigma, seed=-1, xp=np).
  • preprocess.data_with_complex_gaussian_noise_added(data, sigma, seed=-1, xp=np).

Changed behaviour

  • aa.SimulatorInterferometer.via_image_from: NaN-noise-map runtime guard now NumPy-side only (Python if <tracer>: triggers TracerBoolConversionError under JAX).
  • preprocess.gaussian_noise_via_shape_and_sigma_from: JAX path uses jax.random.PRNGKey(seed) + jax.random.normal (scaled by sigma). seed=-1 derives a time-based key.

Migration

  • No breaking changes. Existing callers without xp= continue to work — use_jax defaults to False.

Known limitations

  • @jax.jit wrap of via_image_from is currently blocked by the same Array2D.native jit-incompatibility flagged in the SimulatorImaging PR. Eager JAX usage works today.

Out of scope (separate PRs)

  • xp=np + jnp-backed-grid mismatch ValueError in AbstractMaker.__init__ (Phase 2 PR 4).
  • Array2D.native jit-safety refactor (unblocks @jax.jit for all simulators).

🤖 Generated with Claude Code

Adds use_jax constructor flag to aa.SimulatorInterferometer and threads xp
through via_image_from. When use_jax=True, complex Gaussian noise routes
through jax.random.PRNGKey + jax.random.normal instead of numpy's RNG.

Also fixes the pre-existing signature asymmetry: aa.SimulatorImaging.via_image_from
accepts xp=None but aa.SimulatorInterferometer.via_image_from did not. Both now
match: via_image_from(image, xp=None).

preprocess.gaussian_noise_via_shape_and_sigma_from, data_with_gaussian_noise_added,
and data_with_complex_gaussian_noise_added all gain xp=np parameter, with JAX
path routing through jax.random.

The NaN-noise-map runtime guard is now NumPy-side only (same pattern as
SimulatorImaging PR — Python `if <tracer>:` triggers TracerBoolConversionError
under JAX).

Eager JAX path works. @jax.jit wrap of via_image_from is currently blocked by
the same pre-existing autoarray .native limitation flagged in the SimulatorImaging
PR (slim/native reshape uses indexed assignment, not jit-traceable). Separate
task needed for that.

Part of Phase 2 PR 3 of z_features/jax_user_intro.md. Companion PRs add
SimulatorInterferometer subclass overrides in PyAutoLens and PyAutoGalaxy.

Design doc: admin_jammy/notes/jax_interface.md
Issue: PyAutoArray#334

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Jammy2211 Jammy2211 added the pending-release PR queued for the next release build label May 24, 2026
@Jammy2211 Jammy2211 merged commit 32abd78 into main May 24, 2026
6 checks passed
@Jammy2211 Jammy2211 deleted the feature/simulator-interferometer-use-jax branch May 24, 2026 16:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pending-release PR queued for the next release build

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant